Add support for Optimisers.jl #114

Merged · 6 commits · Apr 30, 2022
Conversation

@lorenzoh (Member) commented Apr 30, 2022

Closes #112 (once done). @ToucheSir @darsnack

So this is a first draft for adding Optimisers.jl support (new optims) while keeping compatibility with optimisers in Flux.Optimise (old optims).

Passing in new optims already works, but I've broken support for old optims. Before, FluxTraining.jl was using implicit parameters with Params and Grads objects; I'm not sure how to make the old optims work now that explicit parameters are passed to gradient.

I'll leave some more questions next to the code changes; feedback from you two would be much appreciated!
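For context on the two APIs this has to bridge, here is a minimal, hedged sketch using plain Flux/Zygote/Optimisers.jl (illustration only, not FluxTraining.jl code; the model, data, and loss function are made up):

using Flux, Optimisers, Zygote

model = Dense(2, 1)
xs, ys = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss(m) = Flux.Losses.mse(m(xs), ys)

# Old style: implicit parameters, a Params/Grads pair, Flux.Optimise optimisers.
ps = Flux.params(model)
gs = Zygote.gradient(() -> loss(model), ps)          # a Zygote.Grads
Flux.Optimise.update!(Flux.Optimise.Descent(0.01), ps, gs)

# New style: explicit, structural gradients and Optimisers.jl rules.
state = Optimisers.setup(Optimisers.Descent(0.01), model)
grads, = Zygote.gradient(loss, model)                # a NamedTuple-like gradient
state, model = Optimisers.update!(state, model, grads)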

src/training.jl Outdated
@@ -49,18 +49,29 @@ function step! end
 function step!(learner, phase::TrainingPhase, batch)
     xs, ys = batch
     runstep(learner, phase, (; xs=xs, ys=ys)) do handle, state
-        state.grads = gradient(learner.params) do
-            state.ŷs = learner.model(state.xs)
+        state.grads, _, _ = gradient(learner.model, state.xs, state.ys) do model, xs, ys
lorenzoh (Member Author):
I'm not sure passing in all three is necessary here. We only need the gradient with respect to the model, so I wonder if differentiating with respect to xs and ys as well computes some unneeded gradients for the inputs.
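For reference, a hedged sketch of the shape this could take if only the model is differentiated (lossfn stands in for the learner's loss function; not necessarily the final code):

# Differentiate only with respect to the model; xs and ys are closed over,
# so no gradients are computed for the inputs.
state.grads, = gradient(learner.model) do model
    state.ŷs = model(state.xs)
    lossfn(state.ŷs, state.ys)
end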

src/training.jl Outdated
end
end

# Handle both old Flux.jl and new Optimisers.jl optimisers
function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
    update!(optimizer, model, grads)
lorenzoh (Member Author):
This currently throws an error. For context, params isa Params but grads is no longer a Grads. Is a Params even needed anymore?

darsnack (Member):
Left a comment above about this.
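To make the thread above concrete, a hedged sketch of how _update! could cover both optimiser families (illustration only; a real implementation would cache the Optimisers.jl state tree between steps rather than rebuilding it, and the merged code may differ):

# Old Flux optimisers mutate the parameters in place and need Params/Grads.
function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
    Flux.Optimise.update!(optimizer, params, grads)
    return model
end

# Optimisers.jl rules use an explicit state tree and structural gradients.
# For simplicity this sets the state up on every call; stateful rules like
# momentum need the state kept and threaded through between steps.
function _update!(optimizer::Optimisers.AbstractRule, params, model, grads)
    state = Optimisers.setup(optimizer, model)
    state, model = Optimisers.update!(state, model, grads)
    return model
end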



@testset "Optimisers.jl compatibility" begin
    learner = testlearner(coeff = 3, opt=Optimisers.Descent(0.001))
lorenzoh (Member Author):
This already passes 👍, but all the old optim tests are broken for the time being.

src/training.jl Outdated
-        state.grads = gradient(learner.params) do
-            state.ŷs = learner.model(state.xs)
+        state.grads, = gradient(learner.model) do model
darsnack (Member):
I think here you want to take the gradient w.r.t. learner.params when the optimizer is a Flux.Optimise.AbstractOptimiser. Conversely, if it is not, you take the gradient w.r.t. learner.model like you are doing now.

This is why update! below is erroring: the old optimizers need a Grads object, and you can only get that with implicit params.

I think some dispatch for the gradient call would be easiest. Another option is a utility that takes the model, the gradient w.r.t. it, and a Params, and produces a matching Grads.
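A hedged sketch of that dispatch; _gradient is a hypothetical helper name, and the merged code may differ:

# Implicit mode for old Flux optimisers: differentiate w.r.t. learner.params
# and return a Zygote.Grads, which Flux.Optimise.update! expects.
function _gradient(f, optimizer::Flux.Optimise.AbstractOptimiser, model, params)
    return Zygote.gradient(() -> f(model), params)
end

# Explicit mode for Optimisers.jl rules: differentiate w.r.t. the model itself
# and return the structural gradient.
function _gradient(f, optimizer, model, params)
    grads, = Zygote.gradient(f, model)
    return grads
end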

lorenzoh (Member Author):
Ah, I figured. I was hoping there might be a way to make the same Zygote.gradient call work for both, but I guess not. I'll add a dispatch on the optimiser there.

darsnack (Member):
Well, what you're asking for might come in a later version of Flux as part of the AD-agnostic push, so the code might eventually get simpler.

lorenzoh (Member Author):
That's good to know. Definitely let me know then, so I can clean this up again.

@lorenzoh (Member Author):
@darsnack I added the dispatch for gradient. Can you take a quick look to make sure it looks okay before I merge?
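For illustration, a hedged sketch of how the training step could look with the two dispatching helpers sketched above (_gradient and _update! are placeholder names; learner.optimizer and learner.lossfn are assumed field names, and the merged code may differ):

function step!(learner, phase::TrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (; xs = xs, ys = ys)) do handle, state
        # gradient computation dispatches on the optimiser family
        state.grads = _gradient(learner.optimizer, learner.model, learner.params) do model
            state.ŷs = model(state.xs)
            state.loss = learner.lossfn(state.ŷs, state.ys)
            return state.loss
        end
        # parameter update also dispatches on the optimiser family
        _update!(learner.optimizer, learner.params, learner.model, state.grads)
    end
end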

@darsnack (Member) left a comment:
Looks right to me!

Successfully merging this pull request may close these issues: Use Optimisers.jl.